Ames Mutagenicity

Dataset Description: Mutagenicity is closely related to carcinogenicity. The most widely used assay for testing the mutagenicity of compounds is the Ames test, a short-term bacterial reverse-mutation assay developed by Bruce Ames that detects compounds able to induce genetic damage and frameshift mutations. The dataset is aggregated from four papers.

Task Description: Binary classification. Given a drug SMILES string, predict whether it is mutagenic (1) or not mutagenic (0).

Imports

In [2]:
from forgebox.imports import *  # wildcard import; assumed to provide pd, torch, nn, Dataset, and DataLoader used below
from gc_utils.config import ObjDict
from plotly import express as px
from transformers import AutoModel, AutoTokenizer
import pytorch_lightning as pl
from gc_utils import DBs  # internal database helper, used for the CKB/KEGG tables at the end

Config

In [3]:
config = ObjDict(
    pretrained = "seyonec/ChemBERTA_PubChem1M_shard00_155k",
    bs = 32,
    versions={
        pl.__name__:pl.__version__,
        torch.__name__:torch.__version__
    }
)

config
Out[3]:
{'pretrained': 'seyonec/ChemBERTA_PubChem1M_shard00_155k',
 'bs': 32,
 'versions': {'pytorch_lightning': '1.1.2', 'torch': '1.5.0'}}
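
ObjDict comes from the internal gc_utils package. For readers without gc_utils, it is assumed to be nothing more than a dict with attribute-style access (which is all this notebook relies on, e.g. config.y_mean = ... later); a hypothetical minimal stand-in:

class ObjDict(dict):
    """Hypothetical stand-in: a dict whose keys double as attributes."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value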

Download data

In [3]:
from tdc.single_pred import Tox
data = Tox(name = 'AMES')
split = data.get_split()
Found local copy...
Loading...
Done!
In [4]:
def get_split(split):
    """
    Return the train, valid, and test dataframes
    """
    return split["train"], split["valid"], split["test"]
In [5]:
train, valid, test  = get_split(split)
In [6]:
for df in [train, valid, test]:
    df["sm_len"] = df.Drug.apply(len)
In [7]:
train
Out[7]:
Drug_ID Drug Y sm_len
0 Drug 1 O=[N+]([O-])c1c2c(c3ccc4cccc5ccc1c3c45)CCCC2 1 44
1 Drug 2 O=c1c2ccccc2c(=O)c2c1ccc1c2[nH]c2c3c(=O)c4cccc... 0 96
2 Drug 3 [N-]=[N+]=CC(=O)NCC(=O)NN 1 25
3 Drug 4 [N-]=[N+]=C1C=NC(=O)NC1=O 1 25
4 Drug 6 CCCCN(CC(O)C1=CC(=[N+]=[N-])C(=O)C=C1)N=O 1 41
... ... ... ... ...
5089 Drug 7568 CCC(CCC(C)C1CCC2C3CC=C4CC(O)CCC4(C)C3CCC12C)C(C)C 0 49
5090 Drug 7587 CCCCCCCCCCCCOCCO 0 16
5091 Drug 7593 CCOP(=S)(CC)Sc1ccccc1 0 21
5092 Drug 7598 C=C(C)C1CC=C(C)C(OC(C)=O)C1 0 27
5093 Drug 7602 CC/N=c1\cc2oc3cc(NCC)c(C)cc3c(-c3ccccc3C(=O)OC... 0 54

5094 rows × 4 columns

Y value distribution

In [8]:
train.vc("Y")  # vc: value-counts shorthand (from forgebox)
Out[8]:
Y
1 2759
0 2335
In [9]:
valid.vc("Y")
Out[9]:
Y
1 417
0 311
In [10]:
test.vc("Y")
Out[10]:
Y
1 798
0 658
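
All three splits lean toward the positive class (from the counts above: roughly 54%, 57%, and 55% mutagenic). A quick check of the positive fraction per split:

for name, df in [("train", train), ("valid", valid), ("test", test)]:
    print(name, round((df.Y == 1).mean(), 3))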

SMILES string length distribution

Note that the tokenized sequences will be much shorter than these raw character counts, since the tokenizer merges multi-character fragments into single tokens.

In [11]:
px.histogram(train, x="sm_len")
In [12]:
px.histogram(valid, x="sm_len")
In [13]:
px.histogram(test, x="sm_len")
Since Y is binary, its mean is simply the positive-class fraction:

In [14]:
config.y_mean, config.y_std = train.Y.mean(), train.Y.std()
config.y_mean, config.y_std
Out[14]:
(0.5416175893207695, 0.49831388016187933)

Pretrained models

We download this pretrained model, which accompanies the paper ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction:

In [15]:
config.pretrained
Out[15]:
'seyonec/ChemBERTA_PubChem1M_shard00_155k'
In [16]:
tokenizer = AutoTokenizer.from_pretrained(config.pretrained, use_fast=True)
In [17]:
model = AutoModel.from_pretrained(config.pretrained)
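
As a quick sanity check of the claim above about tokenized lengths (not part of the original run), we can tokenize one SMILES string and compare token count to character count:

smiles = "CCOP(=S)(CC)Sc1ccccc1"  # 21 characters; taken from the training table above
enc = tokenizer(smiles)
len(enc["input_ids"])  # token count including special tokens; typically well below the character count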

Dataset

In [18]:
class ToxDataSet(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.index_col = df.index
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    @staticmethod
    def normalize_y(y):
        # defined for completeness; not actually needed for this binary target
        return (y - config.y_mean) / config.y_std

    @staticmethod
    def denormalize_y(y):
        return (y * config.y_std) + config.y_mean

    def __getitem__(self, idx):
        row = self.df.loc[self.index_col[idx]]
        smiles = row["Drug"]
        y = row["Y"]
        # return x, y tuple
        return smiles, y

    def collate_fn(self, rows):
        x, y = zip(*rows)
        # only input_ids are kept; the attention_mask is dropped,
        # so padded positions are not masked out downstream
        return self.tokenizer(list(x), return_tensors="pt", padding="longest")['input_ids'],\
            torch.FloatTensor(list(y))[:, None]

    def __repr__(self):
        return f"""ToxDataSet:{len(self.df)} rows
    with tokenizer:{self.tokenizer.name_or_path}"""
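
Since collate_fn keeps only input_ids, padded positions are still attended to by the model. A mask-aware variant (a sketch; forward and the Lightning steps would then need to accept the mask as well):

    def collate_fn(self, rows):
        x, y = zip(*rows)
        enc = self.tokenizer(list(x), return_tensors="pt", padding="longest")
        # return the attention mask alongside input_ids so padding can be ignored
        return (enc["input_ids"], enc["attention_mask"]), torch.FloatTensor(list(y))[:, None]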
Note that the official valid split is folded into training here; the held-out test split serves as the validation set.

In [19]:
train_ds = ToxDataSet(pd.concat([train, valid]).reset_index(drop=True), tokenizer)
valid_ds = ToxDataSet(test, tokenizer)
In [20]:
train_ds, valid_ds
Out[20]:
(ToxDataSet:5822 rows
     with tokenizer:seyonec/ChemBERTA_PubChem1M_shard00_155k,
 ToxDataSet:1456 rows
     with tokenizer:seyonec/ChemBERTA_PubChem1M_shard00_155k)

Test a single case

This is what x and y look like before collation:

In [21]:
train_ds[5]
Out[21]:
('[N-]=[N+]=CC(=O)OCC(N)C(=O)O', 1)
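
And this is what a collated mini-batch looks like (a sketch; the second dimension is set by the longest SMILES in the batch):

xb, yb = train_ds.collate_fn([train_ds[0], train_ds[5]])
xb.shape, yb.shape  # (2, longest tokenized length in the batch), (2, 1)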

Lightning

In [22]:
pl.__version__
Out[22]:
'1.1.2'

Data Module

In [23]:
class ToxLDM(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        
    def train_dataloader(self):
        return DataLoader(
            dataset=train_ds,
            batch_size=config.bs,
            collate_fn=train_ds.collate_fn,
            shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(
            dataset=valid_ds, batch_size=config.bs*4,
            collate_fn=valid_ds.collate_fn, shuffle=False)
In [24]:
tox_ldm = ToxLDM()
x,y = next(iter(tox_ldm.train_dataloader()))
x.shape,y.shape
Out[24]:
(torch.Size([32, 42]), torch.Size([32, 1]))

Lightning Module

In [25]:
class ToxLightning(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.base = model
        self.top = nn.Sequential(
            nn.BatchNorm1d(model.config.hidden_size),
            nn.Linear(model.config.hidden_size, 1)
        )
        self.sigmoid = nn.Sigmoid()
        self.crit = nn.BCEWithLogitsLoss()
        self.acc = pl.metrics.Accuracy()
        self.prec = pl.metrics.Precision(num_classes=1)
        self.rec = pl.metrics.Recall(num_classes=1)
            
    def configure_optimizers(self):
        """
        2 optimizers, 1 for base model, 1 for top layer
        """
        base_opt = torch.optim.Adam(self.base.parameters(), lr=1e-6)
        top_opt = torch.optim.Adam(self.top.parameters(),  lr=1e-3)
        return base_opt, top_opt
    
    def forward(self, x):
        cls_vec = self.base(x).pooler_output
        return self.top(cls_vec)
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        y_ = self(x)
        loss = self.crit(y_, y)
        probs = self.sigmoid(y_)  # probabilities for the metric calls
        acc = self.acc(probs, y)
        precision = self.prec(probs, y)
        recall = self.rec(probs, y)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        self.log("train_prec", precision)
        self.log("train_recall", recall)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_ = self(x)
        loss = self.crit(y_, y)
        probs = self.sigmoid(y_)
        acc = self.acc(probs, y)
        precision = self.prec(probs, y)
        recall = self.rec(probs, y)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_prec", precision)
        self.log("val_recall", recall)
        return loss
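
With two optimizers under automatic optimization, Lightning calls training_step once per optimizer per batch (optimizer_idx tells you which), so the pretrained base is fine-tuned gently at lr=1e-6 while the fresh head learns at lr=1e-3. A single optimizer with parameter groups would give the same learning-rate split without the extra pass per batch; a minimal sketch:

    def configure_optimizers(self):
        # one Adam, two parameter groups: gentle steps for the pretrained
        # base, larger steps for the freshly initialized head
        return torch.optim.Adam([
            {"params": self.base.parameters(), "lr": 1e-6},
            {"params": self.top.parameters(), "lr": 1e-3},
        ])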
In [26]:
pl_model = ToxLightning(model)
In [27]:
logger = pl.loggers.TensorBoardLogger("/GCI/tensorboard/tox", log_graph=True )
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss")
In [28]:
trainer = pl.Trainer(
    logger, 
    gpus=1, 
    callbacks=[early_stopping], 
    fast_dev_run=False)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
In [29]:
trainer.fit(pl_model,datamodule=tox_ldm)
  | Name    | Type              | Params
----------------------------------------------
0 | base    | RobertaModel      | 83.5 M
1 | top     | Sequential        | 2.3 K 
2 | sigmoid | Sigmoid           | 0     
3 | crit    | BCEWithLogitsLoss | 0     
4 | acc     | Accuracy          | 0     
5 | prec    | Precision         | 0     
6 | rec     | Recall            | 0     
----------------------------------------------
83.5 M    Trainable params
0         Non-trainable params
83.5 M    Total params

Out[29]:
1
In [31]:
pl_model = pl_model.eval()
In [34]:
def infer_smiles(x):
    x = tokenizer(x, return_tensors="pt")['input_ids']
    with torch.no_grad():
        y_ = pl_model.sigmoid(pl_model(x))
    return y_[0].item()
In [51]:
infer_smiles("CCOP(=S)(CC)Sc1ccccc1")
Out[51]:
0.39659035205841064
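
The example above is a non-mutagenic compound from the training table (Y = 0), and the model indeed scores it below 0.5. infer_smiles handles one string at a time; a batched variant (a sketch, reusing the same padding strategy as collate_fn):

def infer_smiles_batch(smiles_list):
    # returns one mutagenicity probability per SMILES string
    x = tokenizer(smiles_list, return_tensors="pt", padding="longest")["input_ids"]
    with torch.no_grad():
        y_ = pl_model.sigmoid(pl_model(x))
    return y_[:, 0].tolist()

infer_smiles_batch(["CCOP(=S)(CC)Sc1ccccc1", "C(CCC)(CCC)C(=O)O"])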

Inference on CKB drugs

In [41]:
kegg_db = DBs("kegg")

with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match = pd.read_sql("ckb_kegg_drug_match", con = conn)
In [45]:
ckb_kegg_drug_match["ames_score"] = ckb_kegg_drug_match.smiles.apply(
    lambda x: infer_smiles(x) if x else x)  # rows without a SMILES (e.g. biologics) keep None
In [47]:
px.histogram(ckb_kegg_drug_match, x = "ames_score")

Save back to db

In [49]:
with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match.to_sql(
        "ckb_kegg_drug_match",
        if_exists="replace",
        index=False,
        con=conn
    )
In [4]:
kegg_db = DBs("kegg")

with kegg_db.con.connect() as conn:
    ckb_kegg_drug_match = pd.read_sql("ckb_kegg_drug_match", con = conn)
In [6]:
ckb_kegg_drug_match.sample(20)
Out[6]:
ckb kegg_id smiles formula kegg_drug_name ames_score
328 Navitoclax D09935 C1N(CCN(C1)c1ccc(cc1)C(=O)NS(=O)(=O)c1ccc(c(c1... C47H55ClF3N5O6S3 Navitoclax (USAN/INN) 0.116984
123 Triptorelin D06247 [C@@H]1(C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@H](C(... C64H82N18O13 Triptorelin (USAN/INN) 0.298110
408 Polatuzumab vedotin-piiq D10761 None C6670H10317N1745O2087S40 Polatuzumab vedotin (USAN);Polatuzumab vedotin... NaN
218 Parsaclisib D11506 n1cnc2c(c1N)c(nn2[C@H](c1c(c(c(c(c1)Cl)F)[C@@H... C20H22ClFN6O2. HCl Parsaclisib hydrochloride (USAN) 0.818052
321 Osimertinib D10766 c1(nc(ncc1)Nc1cc(c(cc1OC)N(CCN(C)C)C)NC(=O)C=C... C28H33N7O2. CH4SO3 Osimertinib mesylate (USAN);Osimertinib mesila... 0.443737
297 Sonidegib D10729 O1[C@H](CN(C[C@H]1C)c1ncc(cc1)NC(=O)c1cccc(c1C... C26H26F3N3O3. 2H3PO4 Sonidegib phosphate (USAN);Sonidegib diphospha... 0.484098
127 Aspirin D11804 None None Aspirin and vonoprazan;Cabpirin (TN) NaN
269 Vanucizumab D11244 None C6529H10033N1733O2038S46 Vanucizumab (USAN/INN) NaN
421 Durvalumab D10808 None None Durvalumab (USAN/INN);Durvalumab (genetical re... NaN
180 Aldesleukin D11669 None C690H1113N177O203S6. (C19H16N2O4(C2H4O)2n)6 Bempegaldesleukin (USAN/INN) NaN
289 Raloxifene D02217 Cl.c1(ccc(cc1)OCCN1CCCCC1)C(=O)c1c(sc2c1ccc(c2... C28H27NO4S. HCl Raloxifene hydrochloride (JAN/USP);LY 156758;E... 0.437854
104 Taladegib D10671 c1(cc(c(cc1)C(=O)N(C1CCN(CC1)c1c2c(c(nn1)c1ccn... C26H24F4N6O Taladegib (USAN/INN) 0.474562
176 Vopratelimab D11433 None C6488H9986N1726O2014S44 Vopratelimab (USAN/INN) NaN
330 Tiragolumab D11482 None C6620H10206N1742O2074S40 Tiragolumab (USAN/INN) NaN
189 Defactinib D10619 N(c1nccnc1CNc1nc(ncc1C(F)(F)F)Nc1ccc(cc1)C(=O)... C20H21F3N8O3S. HCl Defactinib hydrochloride (USAN) 0.290324
266 Relatlimab D11350 None C6472H9922N1710O2024S38 Relatlimab (USAN) NaN
208 Diclofenac D00903 [K+].c1cccc(c1Nc1c(cccc1Cl)Cl)CC(=O)[O-] C14H10Cl2NO2. K Diclofenac potassium (USP);Cambia (TN);Catafla... 0.766118
406 Erlotinib D04023 c1nc(c2c(n1)cc(c(c2)OCCOC)OCCOC)Nc1ccc(cc1)C#C.Cl C22H23N3O4. HCl Erlotinib hydrochloride (JAN/USAN);Tarceva (TN) 0.899457
109 Valproic Acid D00399 C(CCC)(CCC)C(=O)O C8H16O2 Valproic acid (USP);Depakene (TN) 0.020188
78 Icotinib D11379 C1(CC(OC(C1)(C)C)(C)C)c1ccc(c(n1)C1=CCC(CC1)(C... C27H35N5O2. HCl Edicotinib hydrochloride (USAN) 0.237642

AMES score distribution for drugs in Genomicare CKB

In [13]:
px.histogram(ckb_kegg_drug_match.query("smiles==smiles"), x="ames_score")  # smiles==smiles keeps only rows with a non-null SMILES
